{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook Setup \n",
    "The following cell will install Drake, checkout the underactuated repository, and set up the path (only if necessary).\n",
    "- On Google's Colaboratory, this **will take approximately two minutes** on the first time it runs (to provision the machine), but should only need to reinstall once every 12 hours.  Colab will ask you to \"Reset all runtimes\"; say no to save yourself the reinstall.\n",
    "- On Binder, the machines should already be provisioned by the time you can run this; it should return (almost) instantly.\n",
    "\n",
    "More details are available [here](http://underactuated.mit.edu/drake.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "  import pydrake\n",
    "  import underactuated\n",
    "except ImportError:\n",
    "  !curl -s https://raw.githubusercontent.com/RussTedrake/underactuated/master/scripts/setup/jupyter_setup.py > jupyter_setup.py\n",
    "  from jupyter_setup import setup_underactuated\n",
    "  setup_underactuated()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Energy Shaping Controller\n",
    "\n",
    "First we will design the energy shaping controller (only), and plot the closed-loop phase portrait.  Remember, this system is not actually stable at the upright.  It is only attractive!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import copy\n",
    "\n",
    "from pydrake.all import (DiagramBuilder, Saturation, SignalLogger,\n",
    "                         Simulator, wrap_to, VectorSystem)\n",
    "from pydrake.examples.pendulum import (PendulumParams, PendulumPlant)\n",
    "from underactuated.jupyter import SetupMatplotlibBackend\n",
    "plt_is_interactive = SetupMatplotlibBackend()\n",
    "\n",
    "\n",
    "class EnergyShapingController(VectorSystem):\n",
    "\n",
    "    def __init__(self, pendulum):\n",
    "        VectorSystem.__init__(self, 2, 1)\n",
    "        self.pendulum = pendulum\n",
    "        self.pendulum_context = pendulum.CreateDefaultContext()\n",
    "        self.SetPendulumParams(PendulumParams())\n",
    "        \n",
    "    def SetPendulumParams(self, params):\n",
    "      self.pendulum_context.get_mutable_numeric_parameter(0).SetFromVector(params.CopyToVector())\n",
    "      self.pendulum_context.SetContinuousState([np.pi, 0])\n",
    "      self.desired_energy = self.pendulum.EvalPotentialEnergy(self.pendulum_context)\n",
    "\n",
    "    def DoCalcVectorOutput(self, context, pendulum_state, unused, output):\n",
    "      self.pendulum_context.SetContinuousState(pendulum_state)\n",
    "      params = self.pendulum_context.get_numeric_parameter(0)\n",
    "      theta = pendulum_state[0]\n",
    "      thetadot = pendulum_state[1]\n",
    "      total_energy = self.pendulum.EvalPotentialEnergy(self.pendulum_context) + self.pendulum.EvalKineticEnergy(self.pendulum_context)\n",
    "      output[:] = (params.damping() * thetadot - .1 * thetadot *\n",
    "                         (total_energy - self.desired_energy))\n",
    "\n",
    "\n",
    "def PhasePlot(pendulum):\n",
    "    phase_plot = plt.figure()\n",
    "    ax = phase_plot.gca()\n",
    "    theta_lim = [-np.pi, 3. * np.pi]\n",
    "    ax.set_xlim(theta_lim)\n",
    "    ax.set_ylim(-10., 10.)\n",
    "\n",
    "    theta = np.linspace(theta_lim[0], theta_lim[1], 601)  # 4*k + 1\n",
    "    thetadot = np.zeros(theta.shape)\n",
    "    context = pendulum.CreateDefaultContext()\n",
    "    params = context.get_numeric_parameter(0)\n",
    "    context.SetContinuousState([np.pi, 0])\n",
    "    E_upright = pendulum.EvalPotentialEnergy(context)\n",
    "    E = [E_upright, .1 * E_upright, 1.5 * E_upright]\n",
    "    for e in E:\n",
    "        for i in range(theta.size):\n",
    "            v = ((e + params.mass() * params.gravity() * params.length() *\n",
    "                  np.cos(theta[i])) /\n",
    "                 (.5 * params.mass() * params.length() * params.length()))\n",
    "            if (v >= 0):\n",
    "                thetadot[i] = np.sqrt(v)\n",
    "            else:\n",
    "                thetadot[i] = float(\"nan\")\n",
    "        ax.plot(theta, thetadot, color=[.6, .6, .6])\n",
    "        ax.plot(theta, -thetadot, color=[.6, .6, .6])\n",
    "\n",
    "    return ax\n",
    "\n",
    "\n",
    "builder = DiagramBuilder()\n",
    "\n",
    "pendulum = builder.AddSystem(PendulumPlant())\n",
    "ax = PhasePlot(pendulum)\n",
    "saturation = builder.AddSystem(Saturation(min_value=[-3], max_value=[3]))\n",
    "builder.Connect(saturation.get_output_port(0), pendulum.get_input_port(0))\n",
    "controller = builder.AddSystem(EnergyShapingController(pendulum))\n",
    "builder.Connect(pendulum.get_output_port(0), controller.get_input_port(0))\n",
    "builder.Connect(controller.get_output_port(0), saturation.get_input_port(0))\n",
    "\n",
    "# visualizer = builder.AddSystem(PendulumVisualizer())\n",
    "# builder.Connect(pendulum.get_output_port(0), visualizer.get_input_port(0))\n",
    "\n",
    "logger = builder.AddSystem(SignalLogger(2))\n",
    "builder.Connect(pendulum.get_output_port(0), logger.get_input_port(0))\n",
    "\n",
    "diagram = builder.Build()\n",
    "simulator = Simulator(diagram)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "for i in range(5):\n",
    "    context.SetTime(0.)\n",
    "    context.SetContinuousState(np.random.randn(2,))\n",
    "    simulator.AdvanceTo(4)\n",
    "    ax.plot(logger.data()[0, :], logger.data()[1, :])\n",
    "    logger.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Swing-up and balance\n",
    "\n",
    "Now we will combine our simple energy shaping controller with a linear controller that stabilizes the upright fixed point once we get close enough.  We'll read more about this approach in the Acrobot and Cart-Pole notes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.all import Linearize, LinearQuadraticRegulator\n",
    "\n",
    "\n",
    "def BalancingLQR(pendulum):\n",
    "    context = pendulum.CreateDefaultContext()\n",
    "\n",
    "    context.FixInputPort(0, [0])\n",
    "    context.SetContinuousState([np.pi, 0])\n",
    "\n",
    "    Q = np.diag((10., 1.))\n",
    "    R = [1]\n",
    "\n",
    "    linearized_pendulum = Linearize(pendulum, context)\n",
    "    (K, S) = LinearQuadraticRegulator(linearized_pendulum.A(),\n",
    "                                      linearized_pendulum.B(), Q, R)\n",
    "    return (K, S)\n",
    "\n",
    "  \n",
    "class SwingUpAndBalanceController(VectorSystem):\n",
    "\n",
    "    def __init__(self, pendulum):\n",
    "        VectorSystem.__init__(self, 2, 1)\n",
    "        (self.K, self.S) = BalancingLQR(pendulum)\n",
    "        self.energy_shaping = EnergyShapingController(pendulum)\n",
    "        self.energy_shaping_context = self.energy_shaping.CreateDefaultContext()\n",
    "\n",
    "        # TODO(russt): Add a witness function to tell the simulator about the\n",
    "        # discontinuity when switching to LQR.\n",
    "\n",
    "    def DoCalcVectorOutput(self, context, pendulum_state, unused, output):\n",
    "        xbar = copy(pendulum_state)\n",
    "        xbar[0] = wrap_to(xbar[0], 0, 2. * np.pi) - np.pi\n",
    "\n",
    "        # If x'Sx <= 2, then use the LQR controller\n",
    "        if (xbar.dot(self.S.dot(xbar)) < 2.):\n",
    "            output[:] = -self.K.dot(xbar)\n",
    "        else:\n",
    "            self.energy_shaping_context.FixInputPort(0, pendulum_state)\n",
    "            output[:] = self.energy_shaping.get_output_port(0).Eval(self.energy_shaping_context)\n",
    "            \n",
    "builder = DiagramBuilder()\n",
    "\n",
    "pendulum = builder.AddSystem(PendulumPlant())\n",
    "ax = PhasePlot(pendulum)\n",
    "saturation = builder.AddSystem(Saturation(min_value=[-3], max_value=[3]))\n",
    "builder.Connect(saturation.get_output_port(0), pendulum.get_input_port(0))\n",
    "controller = builder.AddSystem(SwingUpAndBalanceController(pendulum))\n",
    "builder.Connect(pendulum.get_output_port(0), controller.get_input_port(0))\n",
    "builder.Connect(controller.get_output_port(0), saturation.get_input_port(0))\n",
    "\n",
    "logger = builder.AddSystem(SignalLogger(2))\n",
    "builder.Connect(pendulum.get_output_port(0), logger.get_input_port(0))\n",
    "\n",
    "diagram = builder.Build()\n",
    "simulator = Simulator(diagram)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "for i in range(5):\n",
    "    context.SetTime(0.)\n",
    "    context.SetContinuousState(np.random.randn(2,))\n",
    "    simulator.AdvanceTo(4)\n",
    "    ax.plot(logger.data()[0, :], logger.data()[1, :])\n",
    "    logger.reset()\n",
    "    \n",
    "ax.set_xlim(np.pi - 3., np.pi + 3.)\n",
    "ax.set_ylim(-5., 5.)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
